Change tensorflow model format to SavedModel to support sub-classed models #628

ascillitoe wants to merge 26 commits into SeldonIO:master from ascillitoe:feature/tf_SavedModel

Conversation
```diff
 try:  # legacy load_model behaviour was to return None if not found. Now it raises an error, hence the try-except.
     model = load_model(filepath, load_dir='encoder')
-except FileNotFoundError:
+except OSError:
```
Changed to `OSError` because we are now relying on `tf.keras.models.load_model` to raise an error when loading fails, and this is what it raises...
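For illustration, a minimal sketch of the pattern under discussion, using a stand-in loader rather than the real `tf.keras.models.load_model` (note that `FileNotFoundError` is a subclass of `OSError`, so catching `OSError` is the wider net):

```python
def load_model_stub(filepath):
    # Stand-in for tf.keras.models.load_model, which raises OSError
    # (e.g. "No file or directory found at ...") when loading fails.
    raise OSError(f'No file or directory found at {filepath}')

def load_optional_encoder(filepath):
    # Legacy behaviour: an absent optional sub-model should yield None
    # rather than an error. FileNotFoundError is a subclass of OSError,
    # so this except clause covers both.
    try:
        return load_model_stub(filepath)
    except OSError:
        return None

print(load_optional_encoder('encoder'))  # → None
```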
```diff
 elif isinstance(detector, (ChiSquareDrift, ClassifierDrift, KSDrift, MMDDrift, TabularDrift)):
     if model is not None:
-        save_model(model, filepath, save_dir='encoder')
+        save_model(model, filepath, save_dir='encoder', save_format='h5')
```
Stick to saving in h5 format for legacy saves...
Codecov Report

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master     #628      +/-   ##
==========================================
- Coverage   80.35%   80.33%   -0.03%
==========================================
  Files         137      137
  Lines        9300     9304       +4
==========================================
+ Hits         7473     7474       +1
- Misses       1827     1830       +3
```
Flags with carried forward coverage won't be shown.
Check out this pull request on ReviewNB: see visual diffs & provide feedback on Jupyter Notebooks.
The commits from f61af31 onwards contain three primary changes (based on discussion with @jklaise):
````
model that can be saved and loaded with `torch.save(..., pickle_module=dill)` and `torch.load(..., pickle_module=dill)`.

```{note}
- The {obj}`~alibi_detect.cd.tensorflow.HiddenOutput` utility class is not currently compatible with subclassed models.
```
````
In what sense is it not? How could it be made compatible in the future?
HiddenOutput works by building a new `tf.keras.Model` from the original model's `input` and `layers` attributes.
alibi-detect/alibi_detect/cd/tensorflow/preprocess.py
Lines 67 to 85 in c0c5e64
This works fine for `tf.keras.Model`s constructed from `tf.keras.Sequential` etc., as they have a pre-defined structure for these attributes. It doesn't work out-of-the-box for subclassed models, presumably because there are quite a few different ways you can construct these models internally. Probably worth opening an issue to explore this one further...
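A rough sketch of the idea (illustrative only, not the actual alibi-detect implementation — the `hidden_output` helper and both model definitions are made up here) showing why a Sequential model works while a subclassed one does not:

```python
import numpy as np
import tensorflow as tf

def hidden_output(model: tf.keras.Model, layer: int) -> tf.keras.Model:
    # Build a new model exposing an intermediate layer's output, in the
    # spirit of HiddenOutput (sketch, not the real code).
    return tf.keras.Model(inputs=model.inputs,
                          outputs=model.layers[layer].output)

# Sequential/functional models have a pre-defined symbolic graph, so the
# `inputs` and per-layer `output` attributes exist once the model is built:
seq = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                           tf.keras.layers.Dense(8),
                           tf.keras.layers.Dense(2)])
feats = hidden_output(seq, layer=0)(np.zeros((1, 4), dtype='float32'))
print(feats.shape)  # (1, 8)

# A subclassed model defines its forward pass imperatively in `call`, so
# there is no symbolic graph to rebuild from, and the same helper fails:
class Subclassed(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(2)

    def call(self, x):
        return self.dense(x)

sub = Subclassed()
sub(np.zeros((1, 4), dtype='float32'))  # build by calling
try:
    hidden_output(sub, layer=0)
except (AttributeError, ValueError, TypeError) as e:
    print('hidden_output failed:', type(e).__name__)
```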
I see, feels like a docstring is required for `HiddenOutput` (and other utility classes) to describe behaviour/limitations. For another PR though.
I'll open an issue shortly 👍🏻
Edit: #734
```
@@ -18,6 +18,7 @@
def load_model(filepath: Union[str, os.PathLike],
```
Extra functionality sneaking into this PR... Worth adding changelog entries to this PR so everything is documented and not missed upon release?
Not 100% clear what you mean here? The functionality of passing kwargs to load_detector? It is extra functionality, but it is interlinked with the PR, as custom_objects needs to be passed to load_detector.
Edit: reading again, I see what you mean, since we also pass kwargs to pytorch. I could factor this out into a separate PR if preferred...
No need, but would appreciate a changelog as part of the PR.
```diff
 from alibi_detect.models.tensorflow.autoencoder import (AE, AEGMM, VAE, VAEGMM,
                                                         DecoderLSTM,
                                                         EncoderLSTM, Seq2Seq)
+from alibi_detect.utils.tensorflow.misc import check_model
```
Minor, but why not keep the function in this module since it's specifically used during loading?
Mmn, I put it here so that it's next to clone_model, which is quite related. But I can move it back to saving since we only use it there...
```python
# Check model cloning doesn't raise error
clone_model(model)
...
except Exception as error:
    if raise_error:
        raise ValueError(msg) from error
    else:
        warnings.warn(msg + f"Original error message: \n\t{error}")
```
Not sure this works that well together, e.g. we check if cloning works but in any case the error is never raised. So effectively this would blow up again when used inside detectors that do use cloning? Or is the intention to call this method there also?
The problem I am having here is that there are a large number of possible failure modes that can occur due to not specifying the custom_objects properly. Depending on the exact tf version, and whether the model itself is subclassed, or just its layers, either ValueError, TypeError or NotImplementedError is raised, and this happens either during inference or cloning.
The widest net I've been able to cast is to simply test if cloning works, and to check if there might be problems at inference, check whether the model is a RevivedNetwork, which is what it's loaded as if custom objects are missing (I don't try to actually call the model as data isn't available at this point).
A wide net is nice in terms of hopefully catching most errors, but has the downside of throwing errors when things might have actually worked. For example, RevivedNetworks generally work for inference but not cloning. So a user can actually get away with not passing custom_objects if the model is just to be used for preprocessing. Hence why I went with a warning instead of an error...
Maybe one compromise is to raise a warning if the model is a RevivedNetwork, since this can cause issues, but raise the ValueError (with the more coherent message) if cloning fails?
I'm not sure, cloning failure is only relevant for a specific set of detectors, so somehow feels like it should be checked only for that subset. Maybe leave this as is and wrap the relevant detector calls to clone() in try/except with a customized Alibi error message?
Sorry, just coming back on this, I've thought about it a bit more. There are just so many different failure modes, and the problem with just checking if cloneable is that there is one nasty failure mode where inference works, cloning works, BUT inference after cloning doesn't work 🤯 (this occurs in example 5 here, the custom call method is lost when cloning...).
This is what motivates casting the wide net by checking if a RevivedNetwork too. But then throwing a warning for this would prevent users getting away with cases that would otherwise work with no custom_objects provided (generally when no cloning).
I see two options:
- Leave pretty much as is. The downside at the moment is that by swallowing the `ValueError` from cloning and raising it as a warning, it can be somewhat lost in the later actual errors that occur (i.e. when cloning is done for real).
- Turn the warning into an error. Be up front about requiring `custom_objects` to be passed in all cases where there are custom objects involved. This is slightly less convenient for the user, but means we don't have to worry about these many different failure modes that might also change in future versions.
Raising an error here would pretty much be equivalent to outright not supporting revived models? Perhaps it's not a bad idea, especially if that functionality is not well documented on tensorflow docs.
Yes exactly. It would mean any models containing custom classes loaded without providing those custom classes (via registering, or custom_objects kwarg) would be pretty much guaranteed to fail. Thinking about it more this would probably be what I would lean towards. Although to be clear, this does also mean that, for example, if a user had a relatively simple tf.keras.Sequential, but with one custom layer, they would have to provide this layer at load time even if they don't use the model in a detector with any cloning...
The tensorflow behaviour with this has been so fluid in versions 2.9, 2.10 and 2.11 that I do think it'd be the safest option though... (for now)
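To make the two options concrete, here is a minimal sketch of the warn-vs-raise trade-off being discussed, with a boolean standing in for whether `clone_model` succeeds (`check_model_stub` and its message are illustrative, not the actual `check_model` code):

```python
import warnings

def check_model_stub(cloneable: bool, raise_error: bool = False) -> None:
    # Stand-in for the check: attempt a clone and either raise or warn.
    msg = ('Model may not have been loaded correctly; custom_objects may '
           'need to be provided at load time. ')
    try:
        if not cloneable:  # stand-in for clone_model(model) blowing up
            raise TypeError('cannot clone model with custom call method')
    except Exception as error:
        if raise_error:
            raise ValueError(msg) from error
        else:
            warnings.warn(msg + f'Original error message: \n\t{error}')

# Option 1 (current): swallow the failure as a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    check_model_stub(cloneable=False, raise_error=False)
print(len(caught))  # → 1

# Option 2: be up front and raise immediately, chaining the original error.
try:
    check_model_stub(cloneable=False, raise_error=True)
except ValueError as e:
    print('raised, original error chained:', e.__cause__ is not None)  # → True
```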
```diff
-        raise FileNotFoundError(f'{model_name} not found in {model_dir.resolve()}.')
-    model = tf.keras.models.load_model(model_dir.joinpath(model_name), custom_objects=custom_objects)
+    # Load model
+    model = tf.keras.models.load_model(filepath, **kwargs)
```
Seems like we're throwing away the validation code for the existence of the model? Or is it done from higher up in another caller?
It's not actually done from higher up. Rather, I realised that the validation might be superfluous, since tf.keras.models.load_model already raises `OSError: No file or directory found at test.h5` if a filepath to a `.h5` model is passed and one doesn't exist, and `OSError: SavedModel file does not exist at: test/{saved_model.pbtxt|saved_model.pb}` if a directory is passed.
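This can be seen directly. The snippet below reflects the tf 2.x behaviour discussed here; newer Keras releases may raise `ValueError` instead, hence the wider except clause (`try_load` is just an illustrative helper):

```python
import tensorflow as tf

def try_load(path):
    # tf.keras.models.load_model performs its own existence checks, so
    # an explicit check beforehand is superfluous.
    try:
        tf.keras.models.load_model(path)
    except (OSError, ValueError) as e:
        return type(e).__name__
    return 'loaded'

# Both a missing .h5 file and a missing SavedModel directory fail cleanly:
print(try_load('does_not_exist.h5'))
print(try_load('does_not_exist_dir'))
```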
```python
def load_model(filepath: Union[str, os.PathLike],
               filename: str = 'model',
               custom_objects: dict = None,
               layer: Optional[int] = None,
```
Minor, but I don't fully agree with omitting filename from these internal save/load functions. What we save in the function signature, we pay for at every call site, having to remember to do .joinpath(filename). (OTOH, in the old behaviour, having a default model name is also likely not desirable, as forgetting to set it would result in a perhaps unexpected default.)
Pretty much the only reason I made this change is that the filename is only there for legacy loading (legacy as in, saving to .h5, and legacy as in loading files with different names such as encoder.h5). For the "modern" loading we simply do:
```python
if flavour == Framework.TENSORFLOW:
    model = load_model_tf(src, layer=layer, **kwargs)
```

So I saw it as a trade-off between carrying around more complexity in load_model to facilitate the legacy functionality in load_detector_legacy, or adding some complexity to the calls in load_detector_legacy to simplify the load_model function... (which just happens to be used for both modern and legacy loading).
Similar story for save_model...
> Also, noting that there's a few subtle changes in the behaviour of internal functions for saving/loading, need to be extra vigilant new bugs haven't been introduced.
p.s. below is very true though... tweaking anything to do with legacy save/load does bring the potential for bugs like the ones v0.10.5 fixed. I've run the same tests of loading old artefacts we ran in #732 and everything passes, but that isn't 100% comprehensive...
jklaise left a comment
Generally LGTM, but would appreciate a changelog with all the changes. Also, noting that there are a few subtle changes in the behaviour of internal functions for saving/loading, we need to be extra vigilant that new bugs haven't been introduced.
Postponing this to
This PR changes the format we use to serialize TensorFlow models from the old HDF5 to the newer SavedModel format.
Motivation
As well as being the default (and recommended) TensorFlow model format, the `SavedModel` format has the advantage of supporting serialisation of sub-classed tensorflow models. That is, models constructed by subclassing `tf.keras.Model`, rather than by using `tf.keras.Sequential`, `tf.keras.models.Model` etc.

Limitations
- If all custom objects are not provided to `tf.keras.models.load_model` via `custom_objects` (or registered with `@tf.keras.utils.register_keras_serializable()`), the model will be loaded as a `keras.saving.saved_model.load.<model_class>`*. This is a rough copy of the original serialized model, that behaves the same wrt inference, but cannot be cloned (which is done in a number of learned detectors such as `ClassifierDrift`). To load the fully-functional model, all custom objects must be supplied at load time.
- `layer` is specified in the ModelConfig.

\* in `tensorflow>=2.9`. In older versions, loading of the model will fail completely if the custom objects are not provided.

Main changes
- Changed `save_format` from `'h5'` to `'tf'` in `save_model` and `load_model`, although stuck with `h5` for the legacy save/load functions.
- Removed the `custom_objects` dictionary via config since support for this was very flaky. Custom objects in the dictionary could only realistically be specified as registered object strings (`@mymodel` etc). However, this is confusing as tensorflow already has its own `@tf.keras.utils.register_keras_serializable()` decorator.
- `load_detector` now allows arbitrary `kwargs`, which are passed to `tf.keras.models.load_model` (or `torch.load`). This is to be used to provide the `custom_objects` at load time (see example below).
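A minimal sketch of the kwarg-forwarding described above, with stand-in functions rather than the real alibi-detect internals (`backend_load_model`, the directory name, and `MyEncoder` are all hypothetical):

```python
def backend_load_model(filepath, custom_objects=None):
    # Stand-in for tf.keras.models.load_model: just report what it got.
    return {'filepath': filepath, 'custom_objects': custom_objects}

def load_detector(filepath, **kwargs):
    # Stand-in for alibi_detect's load_detector: arbitrary kwargs are
    # passed through to the backend loader, so custom objects can be
    # supplied at load time.
    return backend_load_model(filepath, **kwargs)

class MyEncoder:  # hypothetical custom model class
    pass

detector = load_detector('detector_dir',
                         custom_objects={'MyEncoder': MyEncoder})
print(detector['filepath'])  # → detector_dir
# custom_objects has reached the backend loader unchanged.
```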
Example

Example notebook demonstrating serialisation of a detector with a subclassed tensorflow model. Observe how the custom objects must be passed to `load_detector` in order to avoid the error `KeyError: 'layers'`.

Backwards compatibility
`tf.keras.models.load_model` automatically detects whether a given model path represents a `h5` model or `SavedModel`. This means we should be backwards compatible, in that we can simply move to saving `SavedModel`'s, but still support loading of legacy `h5` models.

TODO's
- [ ] `SavedModel` format in docs.
- [ ] ~~Add a more involved example of passing custom objects to `load_detector`.~~ More challenging than first envisaged; I wanted to demonstrate on the amazon example, where a subclassed `ClassifierTF` model is used, but saving here is not supported due to the use of `tokenize_transformer`. This is not related to subclassed models so would like to leave for a follow-up PR. See SavedModel to support sub-classed models #628 (review).

Old notes etc
This notebook contains some experiments run to explore limitations wrt the `SavedModel` format.